import json
import os
import matplotlib.pyplot as plt
from tqdm import tqdm
import tiktoken
from call_gpt import call_gpt

# Example input sequence, kept for reference only (the script reads its sequences from the task metadata)
# sequence = [
#     85.712, 88.3041, 92.5926, 92.8607, 91.7283, 93.5941, 93.2984, 92.2222, 93.7247, 95.255, 
#     97.6129, 100.362, 101.964, 100.989, 99.1552, 98.7552, 95.5816, 92.1016, 93.0909, 90.6265, 
#     90.889, 93.4845, 90.4035, 87.4337, 89.0199, 87.6601, 84.599, 87.4276, 84.2472, 84.2472, 
#     84.2472, 79.6562, 82.7807, 87.2078, 89.5195, 89.8677, 87.0414, 89.521, 91.3881, 90.4083, 
#     92.1826, 93.4565, 91.3362, 91.4634, 90.5948, 86.6206, 88.9109, 85.1013, 81.9998, 84.546, 
#     88.5653, 89.9294, 91.7537, 93.3198, 89.2661, 91.4882, 92.0078, 87.5357, 83.9893, 87.0735, 
#     85.6814, 88.9019, 87.6985, 88.1173, 89.7935, 88.3208, 86.3542, 87.7306, 85.2927, 89.2897, 
#     87.0666, 88.2948, 90.0, 89.6108, 90.999, 90.1045, 90.8312, 92.7045, 90.5704, 90.1796, 
#     92.3926, 91.6074, 91.122, 92.3337, 87.9201, 88.6984, 92.1454, 91.5105, 92.0245, 91.6564, 
#     91.7597, 92.5793, 94.0613, 95.0294, 95.5366, 95.0073, 95.067, 97.0809, 96.3891, 97.2177
# ]

# Root directory of the datasets; the task metadata is read from
# <dataset_dir>/<task>/task.json further down in this script.
dataset_dir = '../dataset'
# Task name, used as the sub-directory of dataset_dir.
task = 'time_series_prediction'
# Only the last max_input_length points of each input sequence are kept.
max_input_length = 100
# Model name; selects the tiktoken encoding below and is the model referenced
# by the (currently commented-out) call_gpt invocation.
model = 'gpt-4-vision-preview'

# Prompt template; {sequence} is filled with the comma-separated input values.
# NOTE: the explicit \n after "Sequence:" plus the literal line break that
# follows it produces a blank line before the sequence.
standard_prompt = '''Help me to provide a very rough estimation based on the provided sequence. I will provide a sequence and you will predict the remaining sequence. The sequence is represented by decimal strings separated by commas.

Please continue the following sequence without producing any additional text. Do not say anything like 'the next terms in the sequence are', just return the numbers. Sequence:\n
{sequence}'''

# Tokenizer for `model`; get_avg_tokens_per_step uses it to count tokens.
encoding = tiktoken.encoding_for_model(model)

def get_avg_tokens_per_step(input_str, time_sep=','):
    """Return the average number of tokens per step of *input_str*.

    Tokens are counted with the module-level ``encoding``. Steps are the
    pieces produced by ``input_str.split(time_sep)``, so a trailing separator
    contributes one additional (empty) step to the count.
    """
    n_tokens = len(encoding.encode(input_str))
    n_steps = len(input_str.split(time_sep))
    return n_tokens / n_steps

# Accumulators for error metrics; nothing in the visible part of this script
# appends to them yet.
mae_list = []
mape_list = []
std_list = []

# Load the task metadata; the loop below reads each entry's 'input' sequence.
metadata_path = os.path.join(dataset_dir, task, 'task.json')
with open(metadata_path, 'r', encoding='utf8') as f:
    metadata = json.load(f)

for sid, item in enumerate(tqdm(metadata)):
    # Keep only the most recent max_input_length points of the sequence.
    input_sequence = item['input'][-max_input_length:]

    # Render the sequence as a line plot and save it for the vision model.
    image_path = f'temp_images/time_series_prediction/{sid}.png'
    # savefig does not create missing parent directories, so make sure the
    # output directory exists before writing.
    os.makedirs(os.path.dirname(image_path), exist_ok=True)
    fig = plt.figure(figsize=(15, 5))
    plt.plot(input_sequence, marker='o', linestyle='-', color='b')
    plt.title("Sequence Analysis")
    plt.xlabel("Sequence Index")
    plt.ylabel("Sequence Value")
    plt.grid(True)
    plt.savefig(image_path, dpi=300)
    # pyplot keeps every figure alive until it is closed explicitly; close it
    # so figures do not pile up across iterations.
    plt.close(fig)

    # Serialise the sequence as comma-separated values with a trailing comma,
    # so the text ends where the next value is expected.
    input_str = ','.join(map(str, input_sequence)) + ','
    prompt = standard_prompt.format(sequence=input_str)
    avg_tokens = get_avg_tokens_per_step(input_str)
    # Token budget of about 1.4 steps; not yet passed to the (commented-out)
    # call_gpt invocation below.
    max_tokens = round(1.4 * avg_tokens)

    # Prefix the instruction that points the model at the image.
    prompt = (
        '(Please pay attention to the image content. Give your estimation based on the image.)\n\n'
        + prompt.rstrip('\n')
    )

    print(prompt)

    # result = call_gpt(prompt, 
    #     model=model, 
    #     image_path=image_path, 
    #     temperature=0,)

    # print(result)

    # NOTE(review): this stops the script with exit status 1 after the first
    # item, which looks like a debugging stop — confirm before running the
    # full dataset. SystemExit is raised directly because the `exit` helper is
    # only installed by the `site` module.
    raise SystemExit(1)
